[RFC][SPIR-V] Add intrinsics to convert to/from ap.float#164252
The intrinsic performs conversions between values whose interpretation differs from their representation in LLVM IR. The intrinsic is overloaded on both its return type and first argument. Metadata operands describe how the raw bits should be interpreted before and after the conversion. Current patch adds only lowering to SPIR-V. Signed-off-by: Sidorov, Dmitry <dmitry.sidorov@intel.com>
@llvm/pr-subscribers-llvm-support @llvm/pr-subscribers-backend-spir-v
Author: Dmitry Sidorov (MrSidims)
Changes
The intrinsic performs conversions between values whose interpretation differs from their representation in LLVM IR. The intrinsic is overloaded on both its return type and first argument. Metadata operands describe how the raw bits should be interpreted before and after the conversion. Current patch adds only lowering to SPIR-V.
Addresses https://discourse.llvm.org/t/rfc-spir-v-way-to-represent-float8-in-llvm-ir/87758/10
Patch is 53.65 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/164252.diff
16 Files Affected:
diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index 033910121a54f..fd21ecffa0aed 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -21406,6 +21406,69 @@ environment <floatenv>` *except* for the rounding mode.
This intrinsic is not supported on all targets. Some targets may not support
all rounding modes.
+'``llvm.arbitrary.fp.convert``' Intrinsic
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+
+::
+
+ declare <type> @llvm.arbitrary.fp.convert(
+ <type> <value>, metadata <result interpretation>,
+ metadata <input interpretation>, metadata <rounding mode>,
+ i32 <saturation>)
+
+Overview:
+"""""""""
+
+The ``llvm.arbitrary.fp.convert`` intrinsic performs conversions
+between values whose interpretation differs from their representation
+in LLVM IR. The intrinsic is overloaded on both its return type and first
+argument. Metadata operands describe how the raw bits should be interpreted
+before and after the conversion.
+
+Arguments:
+""""""""""
+
+``value``
+ The value to convert. Its interpretation is described by ``input
+ interpretation``.
+
+``result interpretation``
+ A metadata string that describes the type of the result. The string
+ can be ``"none"`` (no conversion needed), ``"signed"`` or ``"unsigned"`` (for
+ integer types), or any target-specific string for floating-point formats.
+ For example ``"spv.E4M3EXT"`` and ``"spv.E5M2EXT"`` stand for FP8 SPIR-V formats.
+ Using ``"none"`` indicates the converted bits already have the desired LLVM IR type.
+
+``input interpretation``
+ Mirrors ``result interpretation`` but applies to the first argument. The
+ interpretation is target-specific and describes how to interpret the raw bits
+ of the input value.
+
+``rounding mode``
+ A metadata string. The permitted strings match those accepted by
+ :ref:`llvm.fptrunc.round <int_fptrunc_round>` (for example,
+ ``"round.tonearest"`` or ``"round.towardzero"``). The string ``"none"`` may be
+ used to indicate that the default rounding behaviour of the conversion should
+ be used.
+
+``saturation``
+ An integer constant (0 or 1) indicating whether saturation should be applied
+ to the conversion. When set to 1, values outside the representable range of
+ the result type are clamped to the minimum or maximum representable value
+ instead of wrapping. When set to 0, no saturation is applied.
+
+Semantics:
+""""""""""
+
+The intrinsic interprets the first argument according to ``input
+interpretation``, applies the requested rounding mode and saturation behavior,
+and produces a value whose type is described by ``result interpretation``.
+When saturation is enabled, values that exceed the representable range of the target
+format are clamped to the minimum or maximum representable value of that format.
+
Convergence Intrinsics
----------------------
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 12d1c2528f977..b0c8ea1e47fc7 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -1091,6 +1091,14 @@ let IntrProperties = [IntrNoMem, IntrSpeculatable] in {
def int_fptrunc_round : DefaultAttrsIntrinsic<[ llvm_anyfloat_ty ],
[ llvm_anyfloat_ty, llvm_metadata_ty ]>;
+ // Convert between arbitrary interpreted floating-point and integer values.
+ def int_arbitrary_fp_convert
+ : DefaultAttrsIntrinsic<
+ [ llvm_any_ty ],
+ [ llvm_any_ty, llvm_metadata_ty, llvm_metadata_ty,
+ llvm_metadata_ty, llvm_i32_ty ],
+ [ IntrNoMem, IntrSpeculatable ]>;
+
def int_canonicalize : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>],
[IntrNoMem]>;
// Arithmetic fence intrinsic.
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 884c3f1692e94..f0a6c7082985e 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2842,7 +2842,10 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
if (!MDN) {
if (auto *ConstMD = dyn_cast<ConstantAsMetadata>(MD))
MDN = MDNode::get(MF->getFunction().getContext(), ConstMD);
- else // This was probably an MDString.
+ else if (auto *MDS = dyn_cast<MDString>(MD)) {
+ Metadata *Ops[] = {MDS};
+ MDN = MDNode::get(MF->getFunction().getContext(), Ops);
+ } else
return false;
}
MIB.addMetadata(MDN);
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 03da1547b652f..58b80191625c5 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -80,6 +80,7 @@
#include "llvm/IR/Dominators.h"
#include "llvm/IR/EHPersonalities.h"
#include "llvm/IR/Function.h"
+#include "llvm/IR/FPEnv.h"
#include "llvm/IR/GCStrategy.h"
#include "llvm/IR/GlobalAlias.h"
#include "llvm/IR/GlobalValue.h"
@@ -5848,6 +5849,52 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
"unsupported rounding mode argument", Call);
break;
}
+ case Intrinsic::arbitrary_fp_convert: {
+ auto *ResultMAV = dyn_cast<MetadataAsValue>(Call.getArgOperand(1));
+ Check(ResultMAV, "missing result interpretation metadata operand", Call);
+ auto *ResultStr = dyn_cast<MDString>(ResultMAV->getMetadata());
+ Check(ResultStr, "result interpretation metadata operand must be a string",
+ Call);
+ StringRef ResultInterp = ResultStr->getString();
+
+ auto *InputMAV = dyn_cast<MetadataAsValue>(Call.getArgOperand(2));
+ Check(InputMAV, "missing input interpretation metadata operand", Call);
+ auto *InputStr = dyn_cast<MDString>(InputMAV->getMetadata());
+ Check(InputStr, "input interpretation metadata operand must be a string",
+ Call);
+ StringRef InputInterp = InputStr->getString();
+
+ auto *RoundingMAV = dyn_cast<MetadataAsValue>(Call.getArgOperand(3));
+ Check(RoundingMAV, "missing rounding mode metadata operand", Call);
+ auto *RoundingStr = dyn_cast<MDString>(RoundingMAV->getMetadata());
+ Check(RoundingStr, "rounding mode metadata operand must be a string",
+ Call);
+ StringRef RoundingInterp = RoundingStr->getString();
+
+ // Check that interpretation strings are not empty. The actual interpretation
+ // values are target-specific and not validated here.
+ Check(!ResultInterp.empty(),
+ "result interpretation metadata string must not be empty", Call);
+ Check(!InputInterp.empty(),
+ "input interpretation metadata string must not be empty", Call);
+
+ if (RoundingInterp != "none") {
+ std::optional<RoundingMode> RM =
+ convertStrToRoundingMode(RoundingInterp);
+ Check(RM && *RM != RoundingMode::Dynamic,
+ "unsupported rounding mode argument", Call);
+ }
+
+ // Check saturation parameter (must be 0 or 1)
+ auto *SaturationOp = dyn_cast<ConstantInt>(Call.getArgOperand(4));
+ Check(SaturationOp, "saturation operand must be a constant integer", Call);
+ if (SaturationOp) {
+ uint64_t SatVal = SaturationOp->getZExtValue();
+ Check(SatVal == 0 || SatVal == 1,
+ "saturation operand must be 0 or 1", Call);
+ }
+ break;
+ }
#define BEGIN_REGISTER_VP_INTRINSIC(VPID, ...) case Intrinsic::VPID:
#include "llvm/IR/VPIntrinsics.def"
#undef BEGIN_REGISTER_VP_INTRINSIC
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index 96f5dee21bc2a..fe6c5783f61ed 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -149,6 +149,7 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
{"SPV_INTEL_tensor_float32_conversion",
SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion},
{"SPV_KHR_bfloat16", SPIRV::Extension::Extension::SPV_KHR_bfloat16},
+ {"SPV_EXT_float8", SPIRV::Extension::Extension::SPV_EXT_float8},
{"SPV_EXT_relaxed_printf_string_address_space",
SPIRV::Extension::Extension::
SPV_EXT_relaxed_printf_string_address_space},
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 6fd1c7ed78c06..3d13e375c06e4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -215,6 +215,43 @@ SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
});
}
+SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeFloatWithEncoding(
+ uint32_t Width, MachineIRBuilder &MIRBuilder,
+ SPIRV::FPEncoding::FPEncoding FPEncode) {
+ auto Key = std::make_pair(Width, static_cast<unsigned>(FPEncode));
+ if (SPIRVType *Existing = FloatTypesWithEncoding.lookup(Key)) {
+ // Check if the existing type is from the current function
+ const MachineFunction *TypeMF = Existing->getParent()->getParent();
+ if (TypeMF == &MIRBuilder.getMF())
+ return Existing;
+ // Type is from a different function, need to create a new one for current function
+ }
+
+ SPIRVType *SpvType = getOpTypeFloat(Width, MIRBuilder, FPEncode);
+ LLVMContext &Ctx = MIRBuilder.getMF().getFunction().getContext();
+ Type *LLVMTy = nullptr;
+ switch (Width) {
+ case 8:
+ LLVMTy = Type::getInt8Ty(Ctx);
+ break;
+ case 16:
+ LLVMTy = Type::getHalfTy(Ctx);
+ break;
+ case 32:
+ LLVMTy = Type::getFloatTy(Ctx);
+ break;
+ case 64:
+ LLVMTy = Type::getDoubleTy(Ctx);
+ break;
+ default:
+ report_fatal_error("unsupported floating-point width for SPIR-V encoding");
+ }
+
+ SpvType = finishCreatingSPIRVType(LLVMTy, SpvType);
+ FloatTypesWithEncoding.try_emplace(Key, SpvType);
+ return SpvType;
+}
+
SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) {
return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
return MIRBuilder.buildInstr(SPIRV::OpTypeVoid)
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index a648defa0a888..47353fee10065 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -40,6 +40,9 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
DenseMap<SPIRVType *, const Type *> SPIRVToLLVMType;
+ DenseMap<std::pair<unsigned, unsigned>, SPIRVType *>
+ FloatTypesWithEncoding;
+
// map a Function to its definition (as a machine instruction operand)
DenseMap<const Function *, const MachineOperand *> FunctionToInstr;
DenseMap<const MachineInstr *, const Function *> FunctionToInstrRev;
@@ -413,6 +416,10 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
// Return the number of bits SPIR-V pointers and size_t variables require.
unsigned getPointerSize() const { return PointerSize; }
+ SPIRVType *getOrCreateOpTypeFloatWithEncoding(
+ uint32_t Width, MachineIRBuilder &MIRBuilder,
+ SPIRV::FPEncoding::FPEncoding FPEncode);
+
// Returns true if two types are defined and are compatible in a sense of
// OpBitcast instruction
bool isBitcastCompatible(const SPIRVType *Type1,
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index a0cff4d82b500..3963d126d4f73 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -22,6 +22,8 @@
#include "SPIRVUtils.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "llvm/IR/FPEnv.h"
#include "llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/GlobalISel/InstructionSelector.h"
@@ -195,6 +197,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectFloatDot(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
+ bool selectArbitraryFPConvert(Register ResVReg, const SPIRVType *ResType,
+ MachineInstr &I) const;
+
bool selectOverflowArith(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I, unsigned Opcode) const;
bool selectDebugTrap(Register ResVReg, const SPIRVType *ResType,
@@ -2101,6 +2106,439 @@ bool SPIRVInstructionSelector::selectFloatDot(Register ResVReg,
.constrainAllUses(TII, TRI, RBI);
}
+static std::optional<SPIRV::FPEncoding::FPEncoding>
+getFloat8EncodingFromString(StringRef Interpretation) {
+ return StringSwitch<std::optional<SPIRV::FPEncoding::FPEncoding>>(Interpretation)
+ .Case("spv.E4M3EXT", SPIRV::FPEncoding::Float8E4M3EXT)
+ .Case("spv.E5M2EXT", SPIRV::FPEncoding::Float8E5M2EXT)
+ .Default(std::nullopt);
+}
+
+// Enum to classify interpretation types
+enum class InterpretationType {
+ None,
+ Signed,
+ Unsigned,
+ Float8E4M3,
+ Float8E5M2,
+ Unknown
+};
+
+static InterpretationType classifyInterpretation(StringRef Interp) {
+ return StringSwitch<InterpretationType>(Interp)
+ .Case("none", InterpretationType::None)
+ .Case("signed", InterpretationType::Signed)
+ .Case("unsigned", InterpretationType::Unsigned)
+ .Case("spv.E4M3EXT", InterpretationType::Float8E4M3)
+ .Case("spv.E5M2EXT", InterpretationType::Float8E5M2)
+ .Default(InterpretationType::Unknown);
+}
+
+static std::optional<SPIRV::FPEncoding::FPEncoding>
+interpretationToFP8Encoding(InterpretationType Type) {
+ switch (Type) {
+ case InterpretationType::Float8E4M3:
+ return SPIRV::FPEncoding::Float8E4M3EXT;
+ case InterpretationType::Float8E5M2:
+ return SPIRV::FPEncoding::Float8E5M2EXT;
+ default:
+ return std::nullopt;
+ }
+}
+
+// Helper struct to hold parsed intrinsic parameters
+struct ArbitraryConvertParams {
+ Register SrcReg;
+ StringRef ResultInterp;
+ StringRef InputInterp;
+ StringRef RoundingInterp;
+ bool UseSaturation;
+
+ InterpretationType SrcType;
+ InterpretationType DstType;
+
+ static std::optional<ArbitraryConvertParams>
+ parse(const MachineInstr &I, const MachineRegisterInfo *MRI) {
+ unsigned IntrinsicIdx = I.getNumDefs();
+ if (IntrinsicIdx >= I.getNumOperands())
+ return std::nullopt;
+
+ unsigned ValueIdx = IntrinsicIdx + 1;
+ if (ValueIdx + 4 >= I.getNumOperands())
+ return std::nullopt;
+
+ const MachineOperand &ValueOp = I.getOperand(ValueIdx);
+ if (!ValueOp.isReg())
+ return std::nullopt;
+
+ auto GetStringFromMD = [&](unsigned OperandIdx) -> std::optional<StringRef> {
+ const MachineOperand &Op = I.getOperand(OperandIdx);
+ if (!Op.isMetadata())
+ return std::nullopt;
+ const MDNode *MD = Op.getMetadata();
+ if (!MD || MD->getNumOperands() != 1)
+ return std::nullopt;
+ if (auto *Str = dyn_cast<MDString>(MD->getOperand(0)))
+ return Str->getString();
+ return std::nullopt;
+ };
+
+ std::optional<StringRef> ResultInterp = GetStringFromMD(ValueIdx + 1);
+ std::optional<StringRef> InputInterp = GetStringFromMD(ValueIdx + 2);
+ std::optional<StringRef> RoundingInterp = GetStringFromMD(ValueIdx + 3);
+ if (!ResultInterp || !InputInterp || !RoundingInterp)
+ return std::nullopt;
+
+ // Get saturation parameter
+ const MachineOperand &SaturationOp = I.getOperand(ValueIdx + 4);
+ int64_t SaturationValue;
+ if (SaturationOp.isImm()) {
+ SaturationValue = SaturationOp.getImm();
+ } else if (SaturationOp.isReg()) {
+ SaturationValue = foldImm(SaturationOp, MRI);
+ } else {
+ return std::nullopt;
+ }
+
+ ArbitraryConvertParams Params;
+ Params.SrcReg = ValueOp.getReg();
+ Params.ResultInterp = *ResultInterp;
+ Params.InputInterp = *InputInterp;
+ Params.RoundingInterp = *RoundingInterp;
+ Params.UseSaturation = SaturationValue != 0;
+
+ Params.SrcType = classifyInterpretation(Params.InputInterp);
+ Params.DstType = classifyInterpretation(Params.ResultInterp);
+
+ return Params;
+ }
+
+ // Helper methods for type checking
+ bool isSrcFP8() const {
+ return SrcType == InterpretationType::Float8E4M3 ||
+ SrcType == InterpretationType::Float8E5M2;
+ }
+
+ bool isDstFP8() const {
+ return DstType == InterpretationType::Float8E4M3 ||
+ DstType == InterpretationType::Float8E5M2;
+ }
+
+ std::optional<SPIRV::FPEncoding::FPEncoding> getSrcFP8Encoding() const {
+ return interpretationToFP8Encoding(SrcType);
+ }
+
+ std::optional<SPIRV::FPEncoding::FPEncoding> getDstFP8Encoding() const {
+ return interpretationToFP8Encoding(DstType);
+ }
+};
+
+// Helper function to create Float8 type (scalar or vector)
+static SPIRVType *createFloat8Type(unsigned ComponentCount,
+ SPIRV::FPEncoding::FPEncoding Encoding,
+ MachineIRBuilder &MIRBuilder,
+ SPIRVGlobalRegistry &GR) {
+ SPIRVType *Float8ScalarType =
+ GR.getOrCreateOpTypeFloatWithEncoding(8, MIRBuilder, Encoding);
+ if (ComponentCount > 1)
+ return GR.getOrCreateSPIRVVectorType(Float8ScalarType, ComponentCount,
+ MIRBuilder, false);
+ return Float8ScalarType;
+}
+
+// Helper function to build bitcast if type conversion is needed
+static std::optional<Register>
+buildBitcastIfNeeded(Register SrcReg, SPIRVType *SrcType, SPIRVType *TargetType,
+ MachineInstr &I, const TargetInstrInfo &TII,
+ const TargetRegisterInfo &TRI,
+ const RegisterBankInfo &RBI, MachineRegisterInfo *MRI,
+ SPIRVGlobalRegistry &GR) {
+ if (SrcType == TargetType)
+ return SrcReg;
+
+ Register CastReg = MRI->createVirtualRegister(&SPIRV::iIDRegClass);
+ GR.assignSPIRVTypeToVReg(TargetType, CastReg, *I.getMF());
+ auto BitcastMIB =
+ BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpBitcast))
+ .addDef(CastReg)
+ .addUse(GR.getSPIRVTypeID(TargetType))
+ .addUse(SrcReg);
+ if (!BitcastMIB.constrainAllUses(TII, TRI, RBI))
+ return std::nullopt;
+ return CastReg;
+}
+
+bool SPIRVInstructionSelector::selectArbitraryFPConvert(
+ Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const {
+ // Parse intrinsic parameters
+ std::optional<ArbitraryConvertParams> MaybeParams =
+ ArbitraryConvertParams::parse(I, MRI);
+ if (!MaybeParams)
+ return false;
+
+ const ArbitraryConvertParams &Params = *MaybeParams;
+ Register SrcReg = Params.SrcReg;
+ SPIRVType *SrcType = GR.getSPIRVTypeForVReg(SrcReg);
+ LLT SrcLLT = MRI->getType(SrcReg);
+
+ // Parse and validate rounding mode
+ bool RoundingNone = Params.RoundingInterp == "none";
+ std::optional<RoundingMode> RM;
+ if (!RoundingNone) {
+ RM = convertStrToRoundingMode(Params.RoundingInterp);
+ if (!RM || *RM == RoundingMode::Dynamic ||
+ *RM == RoundingMode::NearestTiesToAway)
+ return false;
+ }
+
+ auto GetComponentInfo = [&](const SPIRVType *Type)
+ -> std::pair<const SPIRVType *, unsigned> {
+ if (!Type)
+ return {nullptr, 0};
+ return {GR.getScalarOrVectorComponentType(Type),
+ GR.getScalarOrVectorComponentCount(Type)};
+ };
+
+ MachineIRBuilder MIRBuilder(I);
+
+ // Conversion path 1: FP8 -> Float (e.g., spv.E4M3EXT -> none)
+ if (Params.DstType == InterpretationType::None && Params.isSrcFP8()) {
+ if (RM)
+ return false;
+
+ auto [ResScalarType, ComponentCount] = GetComponentInfo(ResType);
+ if (!ResScalarType || ResScalarType->getOpcode() != SPIRV::OpTypeFloat)
+ return false;
+
+ unsigned Width = ResScalarType->getOperand(1).getImm();
+ if (Width != 16 && Width != 32 && Width != 64)
+ return false;
+
+ unsigned SrcComponentCount = 0;
+ if (SrcType) {
+ SrcComponentCount = GR.getScalarOrVectorComponentCount(SrcType);
+ } else {
+ if (!SrcLLT.isValid())
+ return false;
+ SrcComponentCount = SrcLLT.isVector() ? SrcLLT.getNumElements() : 1;
+ }
+ if (SrcComponentCount != ComponentCount)
+ return false;
+
+ SPIRVType *Float8Type =
+ createFloat8Type(ComponentCount, *Params.getSrcFP8Encoding(), MIRBuilder, GR);
+
+ std::optio...
[truncated]
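The `saturation` operand documented in the LangRef hunk above clamps out-of-range values to the minimum or maximum representable value of the target format. A minimal sketch of just that clamping step in plain C++ (not LLVM code) follows. The helper name `clampToFormatRange` is hypothetical, and the maximum finite magnitudes (448 for E4M3, 57344 for E5M2) are assumptions based on the OCP FP8 formats rather than values taken from the patch.

```cpp
#include <algorithm>
#include <cmath>

// Assumed maximum finite magnitudes of the OCP FP8 formats (not from the patch).
constexpr float kMaxFiniteE4M3 = 448.0f;
constexpr float kMaxFiniteE5M2 = 57344.0f;

// Hypothetical helper: models only the range-clamping half of a saturating
// conversion. Rounding to the target precision would still happen afterwards.
float clampToFormatRange(float Value, float MaxFinite, bool Saturate) {
  if (!Saturate || std::isnan(Value))
    return Value; // NaN is passed through unchanged in this sketch.
  return std::clamp(Value, -MaxFinite, MaxFinite);
}
```

With saturation enabled, `clampToFormatRange(1000.0f, kMaxFiniteE4M3, true)` yields 448; with it disabled the value is returned untouched and the out-of-range behaviour is left to the conversion itself.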
✅ With the latest revision this PR passed the C/C++ code formatter.
arsenm left a comment:
This also needs some level of legalization expansion
Signed-off-by: Sidorov, Dmitry <dmitry.sidorov@intel.com>
s-perron left a comment:
LGTM. I don't have strong feelings on the name.
It would be good to cross-post this to the RFC thread.
efriedma-quic left a comment:
My quick impressions:
- For non-fp8 types, this adds some new functionality for int-to-fp conversions: an explicit static rounding mode. CC @spavloff . The rest of the conversions I think are existing conversions.
- I think, for clarity, I'd prefer to have separate intrinsics for int-to-fp, fp-to-int, and fp-to-fp; they have significantly different semantics, which is hard to describe when they're all mixed together. (What does it mean for an fp type to saturate? Are the integers signed? Can you do int-to-int conversions?)
That is not intended to be introduced, and neither are "standard" float to "standard" float conversions. Originally I was thinking of something like what @arsenm suggested in #164252 (comment), but then I ran into a question: how to represent a mini-float to mini-float conversion. Maybe it's not that odd to have mini-floats as both the input and the output of the [to/from] conversion intrinsics.
You could just use two instructions. Extending to f32 is lossless, so you can extend to f32, then truncate. It might be a little harder to optimize well, but probably most consumers of SPIRV would end up lowering the conversion to that, anyway.
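The two-step approach can be sketched in plain C++ as follows: decode E5M2 to `float` (exact, since every E5M2 value fits in f32), then truncate to E4M3. This is only an illustration under stated assumptions. The bit layouts are assumed to be the OCP FP8 ones (E5M2 with IEEE-style Inf/NaN; E4M3 with bias 7, no infinities, and `S.1111.111` as NaN), the function names are hypothetical, out-of-range values saturate as a choice of this sketch, and the brute-force nearest-value search stands in for a real rounding implementation.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

// Decode E5M2 (1 sign, 5 exponent, 2 mantissa bits, bias 15) to float.
float decodeE5M2(uint8_t Bits) {
  int Exp = (Bits >> 2) & 0x1F;
  int Man = Bits & 0x3;
  float Mag;
  if (Exp == 0x1F)
    Mag = Man ? std::numeric_limits<float>::quiet_NaN()
              : std::numeric_limits<float>::infinity();
  else if (Exp == 0)
    Mag = std::ldexp(static_cast<float>(Man), -16); // subnormal
  else
    Mag = std::ldexp(static_cast<float>(4 + Man), Exp - 17); // normal
  return (Bits & 0x80) ? -Mag : Mag;
}

// Decode E4M3 (1 sign, 4 exponent, 3 mantissa bits, bias 7) to float.
float decodeE4M3(uint8_t Bits) {
  int Exp = (Bits >> 3) & 0xF;
  int Man = Bits & 0x7;
  float Mag;
  if (Exp == 0xF && Man == 0x7)
    Mag = std::numeric_limits<float>::quiet_NaN();
  else if (Exp == 0)
    Mag = std::ldexp(static_cast<float>(Man), -9); // subnormal
  else
    Mag = std::ldexp(static_cast<float>(8 + Man), Exp - 10); // normal
  return (Bits & 0x80) ? -Mag : Mag;
}

// Truncate float to E4M3 by searching all finite codes for the nearest value.
// Ties go to the code with an even low bit; magnitudes above 448 saturate.
uint8_t encodeE4M3Nearest(float Value) {
  if (std::isnan(Value))
    return 0x7F;
  uint8_t Sign = std::signbit(Value) ? 0x80 : 0x00;
  float Mag = std::min(std::fabs(Value), 448.0f);
  uint8_t Best = 0;
  float BestDist = Mag; // distance to code 0, which encodes +0.0
  for (uint8_t Code = 1; Code <= 0x7E; ++Code) {
    float Dist = std::fabs(decodeE4M3(Code) - Mag);
    if (Dist < BestDist || (Dist == BestDist && (Code & 1) == 0)) {
      Best = Code;
      BestDist = Dist;
    }
  }
  return static_cast<uint8_t>(Sign | Best);
}

// Mini-float to mini-float conversion expressed as extend + truncate.
uint8_t convertE5M2ToE4M3(uint8_t Bits) {
  return encodeE4M3Nearest(decodeE5M2(Bits));
}
```

For example, the E5M2 pattern `0x3C` (1.0) maps to the E4M3 pattern `0x38` (also 1.0), while the largest finite E5M2 value `0x7B` (57344) saturates to `0x7E` (448) in this sketch.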
Thanks, I'm going through the changes. I was waiting for approval to publish several SPIR-V extensions, and now that intel/llvm#20467 is public I can go through some concrete examples, including mini-float to mini-float conversions.
Signed-off-by: Dmitry Sidorov <dmitrii.s.sidorov@gmail.com>
Keenuts left a comment:
Thanks for the changes, LGTM on my end
nikic left a comment:
Generally looks reasonable to me.
I'd add a note to both intrinsics that the supported conversions are target dependent. (These aren't going to get generic legalization support, right?)

These should get generic legalization support.

I think this note just makes things worse. It's stating poor QoI as a goal.
Many intrinsics are broken for different type combinations on different targets, but this isn't a desirable state. There isn't anything target dependent required to legalize these.

My expectation for these intrinsics is that they indeed do not have generic legalization support. They're just a generic spelling for target-specific conversions.
These intrinsics, if they were legalized, should use libcall legalizations, but there are no libcalls for these and I don't expect that they are going to be introduced, so I don't think legalization support makes a lot of sense.
Having someone implement inline expansions for all the type combinations without actually having a use case for it sounds like a massive waste of time.

What I had in mind is that for FP4 conversions legalization is quite trivial if we add a lookup table: the FP4 value then becomes just an index. After double-checking, the FP6 case also seems to be trivial, as it's indeed just bit shuffling + rounding.
Then there are the FP8 + FN/FNUZ/E8M0FNU cases, where generic lowering stops being "just shuffle + one rounding bit" and becomes "shuffle + full special-case semantics + careful NaN/zero rules". I lean towards agreeing that generic legalization is feasible, yet I'm not 100% sure there will be interest in the intrinsics outside of the use cases where the hardware supports the conversions in a performant way.
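The lookup-table idea for FP4 can be sketched in plain C++ as follows. This assumes the OCP MX E2M1 layout (1 sign, 2 exponent, 1 mantissa bit, bias 1), which the thread does not name explicitly; the table and the helper `decodeE2M1` are hypothetical and not part of the patch.

```cpp
#include <array>
#include <cstdint>

// All 16 values of the assumed E2M1 format, indexed by the raw 4-bit pattern.
constexpr std::array<float, 16> kE2M1ToFloat = {
    0.0f,  0.5f,  1.0f,  1.5f,  2.0f,  3.0f,  4.0f,  6.0f,
    -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};

// Hypothetical helper: the low nibble is used directly as the table index,
// so the "conversion" is a single load.
constexpr float decodeE2M1(uint8_t Nibble) {
  return kE2M1ToFloat[Nibble & 0xF];
}
```

Because the format has only 16 encodings and no Inf or NaN, the table covers every input; the same trick does not scale as neatly to the FP8 variants with special-case encodings discussed above.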
Do you want me to remove the note and implement generic legalization in this PR? I'm asking because I'm not 100% sure I'll be able to work on this PR past January 14th due to a job switch.

Removed the note. Let's actually land some implementation and then see how it goes.

Leaving the legalization out of this PR is fine for now, but I do think it should be implemented. That significantly increases the utility of the intrinsics.

E.g., for single-source languages it's very useful to have common operations you can rely on for both the host and device code. Restrictions to device-only are limiting.
Reuse llvm's nan def no more undef target-specific statement
nikic left a comment:
LGTM from my side, but please wait for an additional approval.
Signed-off-by: Sidorov, Dmitry <dmitry.sidorov@intel.com>
Signed-off-by: Sidorov, Dmitry <dmitry.sidorov@intel.com>
Signed-off-by: Sidorov, Dmitry <dmitry.sidorov@intel.com>
Signed-off-by: Sidorov, Dmitry <dmitry.sidorov@intel.com>
@arsenm do you have any concerns about merging the PR?
The patch adds two intrinsics: llvm.convert.to.arbitrary.fp and llvm.convert.from.arbitrary.fp. The intrinsics perform conversions between values whose interpretation differs from their representation in LLVM IR. The intrinsics are overloaded on both their return type and first argument. Metadata operands describe how the raw bits should be interpreted before and after the conversion. A typical use case is converting IEEE-754 floating-point types to FP8/FP4 and back for ML applications. Addresses https://discourse.llvm.org/t/rfc-spir-v-way-to-represent-float8-in-llvm-ir/87758/10